"""
STAR dataset, no fixed epsilon method
"""

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
import warnings
from datetime import datetime

# Silence every warning category process-wide (library convergence/deprecation
# notices included) so the progress printout stays readable.
warnings.simplefilter('ignore')


def theoretical_bound(m, beta, N, delta, OPT):
    """
    Compute theoretical performance bound for allocation algorithms.

    The bound is ``(1 - term**beta) * OPT`` with
    ``term = N * log(2 * N / delta) / m``. When ``term >= 1`` (or ``m <= 0``)
    the bound is vacuous and 0.0 is returned.

    Parameters:
    -----------
    m : int
        Sample size
    beta : float
        Bound parameter (0.5 for FullCATE, 1.0 for ALLOC)
    N : int
        Number of groups
    delta : float
        Confidence parameter
    OPT : float
        Optimal allocation value

    Returns:
    --------
    float
        Theoretical lower bound on performance (always a float; 0.0 when vacuous)
    """
    # A non-positive sample size carries no information: the bound is vacuous
    # (and m == 0 would otherwise divide by zero).
    if m <= 0:
        return 0.0
    term = N * np.log(2 * N / delta) / m
    if term >= 1:
        return 0.0
    return float((1 - term**beta) * OPT)


class SampleSizeAnalyzer:
    """
    Analyzer for sample size requirements in CATE allocation problems.

    This class implements methods to evaluate how allocation performance
    varies with sample size across different group construction methods.

    Typical flow: ``process_star_data`` -> one of the ``create_*_groups``
    methods -> ``normalize_cates`` -> ``analyze_sample_size_performance`` ->
    ``plot_sample_size_analysis``.
    """

    # Raw column whose labels define the treatment arm in process_star_data.
    TREATMENT_SOURCE_COL = 'gkclasstype'
    # Test-score columns summed into the composite outcome by process_star_data.
    TEST_COMPONENTS = ('gktreadss', 'gktmathss', 'gktlistss', 'gkwordskillss')
    # Columns added by process_star_data; they encode treatment/outcome and
    # must never be used as covariates.
    DERIVED_COLS = ('treatment', 'outcome', 'total_score')

    def __init__(self, random_seed=42):
        """
        Initialize the analyzer.

        Parameters:
        -----------
        random_seed : int, default=42
            Random seed for reproducibility. Also seeds NumPy's global RNG.
        """
        self.random_seed = random_seed
        np.random.seed(random_seed)
        print(f"Sample Size Analyzer initialized with seed {random_seed}")

    # ------------------------------------------------------------------
    # Data preparation
    # ------------------------------------------------------------------

    def process_star_data(self, df, outcome_col=None):
        """
        Process STAR dataset for analysis.

        Drops 'REGULAR + AIDE CLASS' rows, maps the remaining class types to a
        0/1 ``treatment`` column, and builds ``outcome`` as the sum of the
        available ``TEST_COMPONENTS`` columns.

        Parameters:
        -----------
        df : pandas.DataFrame
            Raw STAR dataset
        outcome_col : str, optional
            Currently ignored: the outcome is always the composite test score.
            Kept for backward compatibility.

        Returns:
        --------
        pandas.DataFrame
            Processed dataset ready for analysis

        Raises:
        -------
        ValueError
            If 'gkschid'/'gkclasstype' are missing or no test-score column exists.
        """
        print(f"Processing STAR data with {len(df)} observations")

        df_processed = df.copy()

        # Validate required columns
        required_cols = ['gkschid', self.TREATMENT_SOURCE_COL]
        missing = [col for col in required_cols if col not in df_processed.columns]
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

        # Filter class types; .copy() so the column assignments below act on an
        # independent frame rather than a slice of the unfiltered one.
        print(f"Original class type distribution: {df_processed['gkclasstype'].value_counts().to_dict()}")
        df_processed = df_processed[df_processed['gkclasstype'] != 'REGULAR + AIDE CLASS'].copy()
        print(f"After excluding aide classes: {len(df_processed)} observations")

        # Create treatment variable. Going through object dtype guarantees a plain
        # numeric result even for a categorical input column (a categorical result
        # would break the .mean()/.sum() calls made later); labels absent from the
        # map become NaN and are dropped below.
        treatment_map = {'SMALL CLASS': 1, 'REGULAR CLASS': 0}
        df_processed['treatment'] = (
            df_processed['gkclasstype'].astype(object).map(treatment_map).astype(float)
        )

        # Create composite outcome
        available_components = [col for col in self.TEST_COMPONENTS if col in df_processed.columns]

        if not available_components:
            raise ValueError("No test score components found")

        # Remove observations with missing test scores
        initial_size = len(df_processed)
        df_processed = df_processed.dropna(subset=available_components)
        print(f"Dropped {initial_size - len(df_processed)} rows due to missing test scores")

        df_processed['total_score'] = df_processed[available_components].sum(axis=1)
        df_processed['outcome'] = df_processed['total_score']

        # Final data cleaning
        initial_size = len(df_processed)
        df_processed = df_processed.dropna(subset=['treatment', 'gkschid'])
        final_size = len(df_processed)

        if initial_size != final_size:
            print(f"Dropped {initial_size - final_size} rows due to missing treatment/school data")

        print(f"Final dataset: {final_size} students")
        print(f"Treatment distribution: {df_processed['treatment'].value_counts().to_dict()}")

        return df_processed

    @classmethod
    def _covariate_columns(cls, df):
        """
        Return the columns of ``df`` usable as covariates for ML-based grouping.

        Excludes the derived treatment/outcome columns, the raw treatment column
        (it determines ``treatment`` exactly), every ``gkt*`` column and every
        composite-outcome component (``gkwordskillss`` does not start with
        ``gkt``, so it is excluded explicitly).

        NOTE(review): all other columns are kept; if ``df`` carries identifiers
        or post-treatment variables they should be excluded too — confirm
        against the dataset schema.
        """
        excluded = set(cls.DERIVED_COLS) | {cls.TREATMENT_SOURCE_COL} | set(cls.TEST_COMPONENTS)
        return [col for col in df.columns
                if col not in excluded and not str(col).startswith('gkt')]

    @staticmethod
    def _encode_features(df, feature_cols):
        """
        Return a fully numeric, NaN-free copy of ``df[feature_cols]``.

        Booleans become 0/1, any other non-numeric column is label-encoded via
        its string form (so missing values become their own 'nan' label), and
        remaining numeric gaps are filled with the column median (0 for a
        column that is entirely missing).
        """
        X = df[feature_cols].copy()
        for col in X.columns:
            if pd.api.types.is_bool_dtype(X[col]):
                X[col] = X[col].astype(float)
            elif not pd.api.types.is_numeric_dtype(X[col]):
                X[col] = LabelEncoder().fit_transform(X[col].astype(str))
        return X.fillna(X.median()).fillna(0)

    @staticmethod
    def _assign_bins(values, edges):
        """
        Map each value to a bin index in ``[0, len(edges) - 2]``.

        ``edges`` are sorted bin boundaries (length n_bins + 1). Bin i holds
        ``edges[i] <= v < edges[i+1]``; the last bin is closed on the right so
        the maximum value is kept (``np.digitize`` on the full edge array would
        put it in a non-existent bin n_bins). NaN values get -1.
        """
        values = np.asarray(values, dtype=float)
        # Digitizing against interior edges only yields indices 0..n_bins-1.
        bins = np.digitize(values, np.asarray(edges)[1:-1])
        bins[np.isnan(values)] = -1
        return bins

    # ------------------------------------------------------------------
    # Group construction
    # ------------------------------------------------------------------

    def create_school_groups(self, df, min_size=6):
        """
        Create groups based on school identifiers.

        Parameters:
        -----------
        df : pandas.DataFrame
            Processed dataset
        min_size : int, default=6
            Minimum group size threshold

        Returns:
        --------
        list
            List of group dictionaries with balanced treatment assignment
        """
        print(f"Creating school-based groups (min_size={min_size})")

        groups = []
        for school_id in df['gkschid'].unique():
            indices = df[df['gkschid'] == school_id].index.tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'school_{school_id}',
                    'indices': indices,
                    'type': 'school'
                })

        print(f"Raw groups created: {len(groups)}")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        print(f"Balanced groups after filtering: {len(balanced_groups)}")

        return balanced_groups

    def create_causal_forest_groups(self, df, n_groups=30, min_size=6):
        """
        Create groups using machine learning-based treatment effect prediction.

        Uses separate Random Forest models for treated and control outcomes,
        then clusters based on predicted treatment effects and covariates.
        Covariates come from ``_covariate_columns`` (treatment label and
        outcome components excluded).

        Parameters:
        -----------
        df : pandas.DataFrame
            Processed dataset
        n_groups : int, default=30
            Target number of groups
        min_size : int, default=6
            Minimum group size threshold

        Returns:
        --------
        list
            List of group dictionaries (empty if an arm is empty or there are
            fewer rows than clusters)
        """
        print(f"Creating machine learning-based groups (target: {n_groups})")

        X = self._encode_features(df, self._covariate_columns(df))

        # Train separate outcome models
        treated_mask = df['treatment'] == 1
        control_mask = df['treatment'] == 0

        if treated_mask.sum() == 0 or control_mask.sum() == 0:
            print("Insufficient treated or control observations")
            return []
        if len(df) < n_groups:
            print(f"Too few observations ({len(df)}) for {n_groups} clusters")
            return []

        rf_treated = RandomForestRegressor(n_estimators=100, random_state=self.random_seed)
        rf_control = RandomForestRegressor(n_estimators=100, random_state=self.random_seed)

        rf_treated.fit(X[treated_mask], df.loc[treated_mask, 'outcome'])
        rf_control.fit(X[control_mask], df.loc[control_mask, 'outcome'])

        # Two-model (T-learner) effect prediction, then cluster on the
        # standardized covariates plus the predicted effect.
        pred_cate = rf_treated.predict(X) - rf_control.predict(X)
        cluster_features = np.column_stack([X.values, pred_cate.reshape(-1, 1)])
        cluster_features = StandardScaler().fit_transform(cluster_features)

        labels = KMeans(n_clusters=n_groups, random_state=self.random_seed).fit_predict(cluster_features)

        groups = []
        for i in range(n_groups):
            indices = df.index[labels == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'ml_cluster_{i}',
                    'indices': indices,
                    'type': 'ml_cluster'
                })

        print(f"Created {len(groups)} machine learning-based groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_propensity_groups(self, df, n_groups=50, min_size=6):
        """
        Create groups based on propensity score stratification.

        Propensity scores are out-of-fold (5-fold) logistic-regression
        predictions on the covariates from ``_covariate_columns``; the raw
        treatment label and outcome components are excluded because they would
        make the score (near-)deterministic.

        Parameters:
        -----------
        df : pandas.DataFrame
            Processed dataset
        n_groups : int, default=50
            Target number of strata
        min_size : int, default=6
            Minimum group size threshold

        Returns:
        --------
        list
            List of group dictionaries
        """
        print(f"Creating propensity score groups (target: {n_groups})")

        X = self._encode_features(df, self._covariate_columns(df))

        # Estimate propensity scores using cross-validation
        prop_scores = cross_val_predict(
            LogisticRegression(random_state=self.random_seed),
            X, df['treatment'], method='predict_proba', cv=5
        )[:, 1]

        # Create quantile-based strata (the top score lands in the last stratum)
        edges = np.quantile(prop_scores, np.linspace(0, 1, n_groups + 1))
        bins = self._assign_bins(prop_scores, edges)

        groups = []
        for i in range(n_groups):
            indices = df.index[bins == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'propensity_{i}',
                    'indices': indices,
                    'type': 'propensity'
                })

        print(f"Created {len(groups)} propensity score groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_performance_groups(self, df, n_groups=50, min_size=6):
        """
        Create groups based on academic score percentiles.

        The score is the row mean of all ``gkt*ss`` columns (column means fill
        gaps). NOTE(review): these columns overlap with the composite outcome
        built in ``process_star_data``, so groups are formed on components of
        the outcome itself — confirm this is intended.

        Parameters:
        -----------
        df : pandas.DataFrame
            Processed dataset
        n_groups : int, default=50
            Target number of groups
        min_size : int, default=6
            Minimum group size threshold

        Returns:
        --------
        list
            List of group dictionaries
        """
        print(f"Creating performance groups (target: {n_groups})")

        # Identify score columns
        score_cols = [col for col in df.columns if str(col).startswith('gkt') and 'ss' in str(col)]
        if not score_cols:
            print("No baseline scores found")
            return []

        baseline_score = df[score_cols].fillna(df[score_cols].mean()).mean(axis=1)

        # Create percentile-based groups (the maximum lands in the last group)
        percentiles = np.linspace(0, 100, n_groups + 1)
        cuts = np.percentile(baseline_score, percentiles)
        bins = self._assign_bins(baseline_score, cuts)

        groups = []
        for i in range(n_groups):
            indices = df.index[bins == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'performance_{i}',
                    'indices': indices,
                    'type': 'performance'
                })

        print(f"Created {len(groups)} performance groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_demographics_groups(self, df, min_size=6, feature_cols=None):
        """
        Create groups based on demographic characteristic combinations.

        Every distinct combination of the available features (rows with any
        missing feature are ignored) becomes one candidate group. Falls back to
        ``create_school_groups`` when no feature is usable.

        Parameters:
        -----------
        df : pandas.DataFrame
            Processed dataset
        min_size : int, default=6
            Minimum group size threshold
        feature_cols : list, optional
            List of demographic features to use

        Returns:
        --------
        list
            List of group dictionaries
        """
        print("Creating demographics groups")

        if feature_cols is None:
            potential_features = ['gkfreelunch', 'race', 'gender', 'birthyear']
        else:
            potential_features = feature_cols

        # Identify available features (present and not entirely missing)
        available_features = [col for col in potential_features
                              if col in df.columns and df[col].notna().sum() > 0]

        if len(available_features) == 0:
            print("No demographic variables found, using school grouping")
            return self.create_school_groups(df, min_size)

        print(f"Using features: {available_features}")

        # Remove observations with missing demographic data
        df_clean = df[available_features].dropna()
        print(f"After removing missing values: {len(df_clean)}/{len(df)} students")

        if len(df_clean) == 0:
            return self.create_school_groups(df, min_size)

        # Create groups based on unique combinations
        unique_combinations = df_clean.drop_duplicates()
        print(f"Found {len(unique_combinations)} unique combinations")

        groups = []
        for _, combo in unique_combinations.iterrows():
            mask = pd.Series(True, index=df.index)
            combo_description = []

            for feature in available_features:
                mask = mask & (df[feature] == combo[feature])
                combo_description.append(f"{feature}={combo[feature]}")

            indices = df[mask].index.tolist()
            combo_id = "_".join(combo_description)

            if len(indices) >= min_size:
                groups.append({
                    'id': combo_id,
                    'indices': indices,
                    'type': 'demographics'
                })

        print(f"Created {len(groups)} demographic groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def _ensure_balance_and_compute_cate(self, df, groups):
        """
        Filter groups for treatment balance and compute conditional average treatment effects.

        A group is kept when its treated share is within [0.15, 0.85] and it
        has at least 3 treated and 3 control rows; its CATE is the treated
        minus control mean of ``outcome``.

        Parameters:
        -----------
        df : pandas.DataFrame
            Dataset
        groups : list
            List of candidate groups

        Returns:
        --------
        list
            List of balanced groups with CATE estimates
        """
        balanced_groups = []

        for group in groups:
            group_df = df.loc[group['indices']]

            treatment_rate = group_df['treatment'].mean()
            n_treated = group_df['treatment'].sum()
            n_control = len(group_df) - n_treated

            # Apply balance and minimum size requirements
            if not (0.15 <= treatment_rate <= 0.85 and n_treated >= 3 and n_control >= 3):
                continue

            # Compute CATE as difference in means
            treated_outcomes = group_df[group_df['treatment'] == 1]['outcome']
            control_outcomes = group_df[group_df['treatment'] == 0]['outcome']
            cate = treated_outcomes.mean() - control_outcomes.mean()

            balanced_groups.append({
                'id': group['id'],
                'indices': group['indices'],
                'size': len(group_df),
                'treatment_rate': treatment_rate,
                'n_treated': int(n_treated),
                'n_control': int(n_control),
                'cate': cate,
                'type': group['type']
            })

        return balanced_groups

    # ------------------------------------------------------------------
    # Simulation
    # ------------------------------------------------------------------

    def normalize_cates(self, groups):
        """
        Normalize CATE values to [0,1] interval (in place).

        Min-max scaling; if all CATEs are equal every group gets 0.5. An empty
        list is returned unchanged.

        Parameters:
        -----------
        groups : list
            List of groups with CATE estimates

        Returns:
        --------
        list
            Groups with normalized CATE values
        """
        if not groups:
            print("No groups to normalize")
            return groups

        cates = [g['cate'] for g in groups]
        min_cate, max_cate = min(cates), max(cates)

        if max_cate > min_cate:
            for group in groups:
                group['normalized_cate'] = (group['cate'] - min_cate) / (max_cate - min_cate)
        else:
            for group in groups:
                group['normalized_cate'] = 0.5

        print(f"CATE normalization: [{min_cate:.3f}, {max_cate:.3f}] → [0, 1]")
        return groups

    def simulate_sampling_trial(self, groups, sample_size, trial_seed):
        """
        Simulate one trial of the uniform sampling process.

        Each of ``sample_size`` draws picks a group uniformly at random and
        observes a Bernoulli(normalized_cate) outcome. Draws come from a private
        ``np.random.RandomState(self.random_seed + trial_seed)``, so trials are
        reproducible and NumPy's global RNG is left untouched; all group choices
        are drawn first, then all outcomes, in two vectorized batches.

        Parameters:
        -----------
        groups : list
            List of groups with normalized CATEs
        sample_size : int
            Total sample size for the trial (non-positive means no draws)
        trial_seed : int
            Trial-specific random seed

        Returns:
        --------
        tuple
            (tau_estimates, sample_counts) float arrays of length len(groups);
            unsampled groups have estimate 0.
        """
        rng = np.random.RandomState(self.random_seed + trial_seed)

        n_groups = len(groups)
        if n_groups == 0:
            return np.zeros(0), np.zeros(0)

        tau_true = np.array([g['normalized_cate'] for g in groups], dtype=float)

        n_draws = max(int(sample_size), 0)
        group_draws = rng.randint(n_groups, size=n_draws)
        outcomes = rng.binomial(1, tau_true[group_draws])

        sample_counts = np.bincount(group_draws, minlength=n_groups).astype(float)
        successes = np.bincount(group_draws, weights=outcomes, minlength=n_groups)

        # Empirical success rate per group; unsampled groups keep a zero estimate.
        tau_estimates = np.zeros(n_groups)
        np.divide(successes, sample_counts, out=tau_estimates, where=sample_counts > 0)

        return tau_estimates, sample_counts

    def analyze_sample_size_performance(self, groups, sample_sizes, budget_percentages, n_trials=50):
        """
        Analyze allocation performance as a function of sample size.

        For each budget fraction p, K = max(1, int(p * N)) groups are selected
        by estimated CATE and scored by the sum of their true normalized CATEs.

        Parameters:
        -----------
        groups : list
            List of groups with normalized CATEs
        sample_sizes : list
            Sample sizes to evaluate
        budget_percentages : list
            Budget constraints as fractions of total groups
        n_trials : int, default=50
            Number of trials per sample size

        Returns:
        --------
        tuple
            (results, optimal_values) dictionaries keyed by budget percentage.
            results[bp] holds parallel lists 'sample_sizes', 'values' (mean
            realized value) and 'stds'.
        """
        print(f"Analyzing sample size performance with {len(groups)} groups")

        n_groups = len(groups)
        tau_true = np.array([g['normalized_cate'] for g in groups])

        # Convert budget percentages to group counts
        budgets = [max(1, int(p * n_groups)) for p in budget_percentages]
        print(f"Budget percentages {budget_percentages} → K values {budgets}")

        # Optimal value: sum of the K largest true CATEs
        true_ranking = np.argsort(tau_true)
        optimal_values = {bp: np.sum(tau_true[true_ranking[-K:]])
                          for bp, K in zip(budget_percentages, budgets)}

        results = {bp: {'sample_sizes': [], 'values': [], 'stds': []} for bp in budget_percentages}

        for sample_size in sample_sizes:
            print(f"  Sample size {sample_size}...")

            budget_trial_values = {bp: [] for bp in budget_percentages}

            for trial in range(n_trials):
                tau_estimates, _ = self.simulate_sampling_trial(groups, sample_size, trial)

                # One sort per trial serves every budget level.
                estimated_ranking = np.argsort(tau_estimates)
                for bp, K in zip(budget_percentages, budgets):
                    selected_indices = estimated_ranking[-K:]
                    # Realized value is measured with the true CATEs
                    budget_trial_values[bp].append(np.sum(tau_true[selected_indices]))

            for bp in budget_percentages:
                results[bp]['sample_sizes'].append(sample_size)
                results[bp]['values'].append(np.mean(budget_trial_values[bp]))
                results[bp]['stds'].append(np.std(budget_trial_values[bp]))

        return results, optimal_values

    # ------------------------------------------------------------------
    # Plotting
    # ------------------------------------------------------------------

    def plot_sample_size_analysis(self, results, optimal_values, method_name, budget_percentages, n_groups):
        """
        Create visualization of sample size analysis with theoretical bounds.

        One panel per budget level (up to 3 per row), values normalized by the
        optimum, saved as ``<method>_N<n_groups>_sample_size_analysis.pdf``.
        The figure is closed after showing so repeated calls do not accumulate
        open figures.

        Parameters:
        -----------
        results : dict
            Results from sample size analysis
        optimal_values : dict
            Optimal values for each budget
        method_name : str
            Name of the grouping method
        budget_percentages : list
            Budget levels analyzed
        n_groups : int
            Total number of groups
        """
        n_panels = len(budget_percentages)
        if n_panels == 0:
            print(f"No budget levels to plot for {method_name}")
            return

        n_cols = min(3, n_panels)
        n_rows = int(np.ceil(n_panels / n_cols))
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 6 * n_rows), squeeze=False)
        axes = axes.flatten()

        delta = 0.05  # Confidence parameter for theoretical bounds

        print(f"\nPlotting {method_name} (N={n_groups})")
        print("="*60)

        for ax, bp in zip(axes, budget_percentages):
            sample_sizes = results[bp]['sample_sizes']
            values = results[bp]['values']
            stds = results[bp]['stds']
            optimal_val = optimal_values[bp]
            # Guard against a zero optimum, which would divide by zero.
            scale = optimal_val if optimal_val > 0 else 1.0

            values_norm = np.array(values) / scale
            stds_norm = np.array(stds) / scale

            # Plot empirical performance
            ax.errorbar(sample_sizes, values_norm, yerr=stds_norm,
                        marker='o', capsize=5, capthick=3, linewidth=6, markersize=8,
                        label='Empirical data', color='blue', alpha=0.8)

            # Plot optimal performance line
            ax.axhline(y=1.0, color='black', linestyle=':', linewidth=2,
                       label='Optimal (1.0)', alpha=0.8)

            # theoretical_bound is proportional to OPT, so evaluating it with
            # OPT=1 yields the bound already normalized by the optimum.
            m_smooth = np.linspace(min(sample_sizes), max(sample_sizes), 200)
            ref_curve_05 = [theoretical_bound(m, 0.5, n_groups, delta, 1.0) for m in m_smooth]
            ref_curve_10 = [theoretical_bound(m, 1.0, n_groups, delta, 1.0) for m in m_smooth]

            ax.plot(m_smooth, ref_curve_05, 'red', linestyle=(0, (3, 2)), linewidth=6,
                    label='FullCATE (β=0.5)', alpha=0.8)
            ax.plot(m_smooth, ref_curve_10, 'green', linestyle=(0, (3, 1, 1, 1)), linewidth=6,
                    label='ALLOC (β=1.0)', alpha=0.8)

            # Formatting
            ax.set_xlabel('Sample size', fontsize=23)
            ax.set_ylabel('Normalized allocation value', fontsize=23)
            ax.set_title(f'Budget = {bp*100:.0f}% (K={max(1, int(bp * n_groups))})',
                         fontsize=24, fontweight='bold')

            ax.legend(fontsize=21, framealpha=0.9)
            ax.grid(True, alpha=0.4, linewidth=1)
            ax.tick_params(axis='both', which='major', labelsize=16, width=1.5, length=5)
            ax.set_ylim(0.2, 1.05)

            for spine in ax.spines.values():
                spine.set_linewidth(1.5)

        # Hide panels left over when the grid has more cells than budget levels.
        for ax in axes[n_panels:]:
            ax.set_visible(False)

        fig.suptitle(f'{method_name} (N={n_groups})', fontsize=24, fontweight='bold')
        fig.tight_layout()

        # Save figure
        clean_name = method_name.replace(' ', '_').replace('(', '').replace(')', '').replace('-', '_')
        pdf_filename = f"{clean_name}_N{n_groups}_sample_size_analysis.pdf"
        fig.savefig(pdf_filename, format='pdf', dpi=300, bbox_inches='tight')
        print(f"Saved plot as: {pdf_filename}")

        plt.show()
        plt.close(fig)
        print(f"Plot complete for {method_name}")


def run_sample_size_analysis(df_star, sample_size_range=None, budget_percentages=None, n_trials=50,
                             min_groups=10):
    """
    Execute comprehensive sample size analysis across all grouping methods.

    Each method runs with a fresh ``SampleSizeAnalyzer`` on freshly processed
    data. A method that raises is reported (message plus traceback) and
    skipped, so one failure does not abort the others.

    Parameters:
    -----------
    df_star : pandas.DataFrame
        STAR dataset
    sample_size_range : list, optional
        Sample sizes to evaluate
    budget_percentages : list, optional
        Budget constraints to analyze
    n_trials : int, default=50
        Number of trials per configuration
    min_groups : int, default=10
        Methods producing fewer balanced groups than this are skipped

    Returns:
    --------
    dict
        Comprehensive results across all methods, keyed by method name; each
        value holds 'results', 'optimal_values' and 'n_groups'.
    """
    import traceback

    if sample_size_range is None:
        sample_size_range = [100, 250, 500, 750, 1000, 1200, 1500, 2000, 5000, 10000, 20000]

    if budget_percentages is None:
        budget_percentages = [0.1, 0.2, 0.3, 0.5, 0.7, 0.9]

    print("SAMPLE SIZE ANALYSIS - EMPIRICAL VS THEORETICAL BOUNDS")
    print(f"Sample sizes: {sample_size_range}")
    print(f"Budget percentages: {budget_percentages}")
    print(f"Trials per sample size: {n_trials}")
    print("="*80)

    # Define grouping methods for analysis
    methods = [
        ('School Groups', lambda analyzer, df: analyzer.create_school_groups(df, min_size=6)),
        ('Demographics', lambda analyzer, df: analyzer.create_demographics_groups(df,
                                               feature_cols=['gkfreelunch', 'race', 'gender'], min_size=6)),
        ('ML Clustering (30)', lambda analyzer, df: analyzer.create_causal_forest_groups(df, n_groups=30, min_size=6)),
        ('ML Clustering (50)', lambda analyzer, df: analyzer.create_causal_forest_groups(df, n_groups=50, min_size=6)),
        ('Propensity Score', lambda analyzer, df: analyzer.create_propensity_groups(df, n_groups=50, min_size=6)),
        ('Performance Groups', lambda analyzer, df: analyzer.create_performance_groups(df, n_groups=50, min_size=6))
    ]

    all_results = {}

    for method_name, method_func in methods:
        print(f"\n{'='*80}")
        print(f"ANALYZING METHOD: {method_name}")
        print("="*80)

        try:
            # Fresh analyzer per method so every method starts from the same seed
            analyzer = SampleSizeAnalyzer()
            df_processed = analyzer.process_star_data(df_star)

            groups = method_func(analyzer, df_processed)

            if len(groups) < min_groups:
                print(f"Insufficient groups ({len(groups)}) for {method_name} - skipping")
                continue

            groups = analyzer.normalize_cates(groups)

            results, optimal_values = analyzer.analyze_sample_size_performance(
                groups, sample_size_range, budget_percentages, n_trials
            )

            all_results[method_name] = {
                'results': results,
                'optimal_values': optimal_values,
                'n_groups': len(groups)
            }

            print(f"Creating plots for {method_name}...")
            analyzer.plot_sample_size_analysis(
                results, optimal_values, method_name, budget_percentages, len(groups)
            )

            print(f"\nSummary for {method_name}:")
            print(f"Number of groups: {len(groups)}")
            print("Optimal values by budget:")
            for bp in budget_percentages:
                print(f"  {bp*100:.0f}%: {optimal_values[bp]:.3f}")

        except Exception as e:
            # Best-effort across methods, but keep the traceback so the failure
            # can actually be diagnosed.
            print(f"Error with {method_name}: {e}")
            traceback.print_exc()

    return all_results


if __name__ == "__main__":
    # Load STAR dataset
    df_star = pd.read_spss('STAR_Students.sav')

    # Configure analysis parameters
    sample_sizes = [100, 250, 500, 750, 1000, 1200, 1500, 2000, 5000, 10000, 20000]
    budget_percentages = [0.1, 0.2, 0.3, 0.5, 0.7, 0.9]

    # Execute comprehensive sample size analysis
    results = run_sample_size_analysis(
        df_star,
        sample_size_range=sample_sizes,
        budget_percentages=budget_percentages,
        n_trials=50
    )

    print("\n" + "="*80)
    print("SAMPLE SIZE ANALYSIS COMPLETE")
    print("="*80)

    # Print overall summary
    print(f"Analysis completed for {len(results)} grouping methods:")
    for method_name, method_data in results.items():
        n_groups = method_data['n_groups']
        print(f"  {method_name}: {n_groups} groups")

    print(f"\nAnalysis parameters:")
    print(f"  Sample sizes: {len(sample_sizes)} levels")
    print(f"  Budget constraints: {len(budget_percentages)} levels")
    print(f"  Trials: 50 per configuration")
    print(f"  Total configurations analyzed: {len(sample_sizes) * len(budget_percentages) * len(results)}")